{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6bYaCABobL5q"
      },
      "source": [
        "##### Copyright 2018 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "FlUw7tSKbtg4"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xc1srSc51n_4"
      },
      "source": [
        "# SavedModel 形式の使用"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-nBUqG2rchGH"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td><a target=\"_blank\" href=\"https://www.tensorflow.org/guide/saved_model\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\">TensorFlow.org で表示</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs-l10n/blob/master/site/ja/guide/saved_model.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Google Colab で実行</a></td>\n",
        "  <td><a target=\"_blank\" href=\"https://github.com/tensorflow/docs-l10n/blob/master/site/ja/guide/saved_model.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">GitHub でソースを表示</a></td>\n",
        "  <td><a href=\"https://storage.googleapis.com/tensorflow_docs/docs-l10n/site/ja/guide/saved_model.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\">ノートブックをダウンロード</a></td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CPE-fshLTsXU"
      },
      "source": [
        "SavedModel には、重みや計算を含む完全な TensorFlow プログラムが含まれます。実行するために元のモデルのビルディングコードを必要としないため、共有やデプロイに便利です（[TFLite](https://tensorflow.org/lite)、[TensorFlow.js](https://js.tensorflow.org/), [TensorFlow Serving](https://www.tensorflow.org/tfx/serving/tutorials/Serving_REST_simple)、または [TensorFlow Hub](https://tensorflow.org/hub)）。\n",
        "\n",
        "このドキュメントでは、低レベル `tf.saved_model` API の使用方法に関する詳細の一部を説明します。\n",
        "\n",
        "- `tf.keras.Model` を使用している場合、`keras.Model.save(output_path)` メソッドのみが必要である可能性があります。「[Keras の保存とシリアライズ](keras/save_and_serialize.ipynb)」を参照してください。\n",
        "\n",
        "- トレーニング中の重みの保存/読み込みのみを実行する場合は、「[トレーニングチェックポイントガイド](./checkpoint.ipynb)」を参照してください。\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9SuIC7FiI9g8"
      },
      "source": [
        "## Keras を使った SavedModel の作成\n",
        "\n",
        "簡単な導入として、このセクションでは、事前にトレーニング済みの Keras モデルをエクスポートし、それを使って画像分類リクエストを送信します。SavedModels のほかの作成方法については、このガイドの残りの部分で説明します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Le5OB-fBHHW7"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import tempfile\n",
        "\n",
        "from matplotlib import pyplot as plt\n",
        "import numpy as np\n",
        "import tensorflow as tf\n",
        "\n",
        "tmpdir = tempfile.mkdtemp()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wlho4HEWoHUT"
      },
      "outputs": [],
      "source": [
        "physical_devices = tf.config.experimental.list_physical_devices('GPU')\n",
        "if physical_devices:\n",
        "  tf.config.experimental.set_memory_growth(physical_devices[0], True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SofdPKo0G8Lb"
      },
      "outputs": [],
      "source": [
        "file = tf.keras.utils.get_file(\n",
        "    \"grace_hopper.jpg\",\n",
        "    \"https://storage.googleapis.com/download.tensorflow.org/example_images/grace_hopper.jpg\")\n",
        "img = tf.keras.preprocessing.image.load_img(file, target_size=[224, 224])\n",
        "plt.imshow(img)\n",
        "plt.axis('off')\n",
        "x = tf.keras.preprocessing.image.img_to_array(img)\n",
        "x = tf.keras.applications.mobilenet.preprocess_input(\n",
        "    x[tf.newaxis,...])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sqVcFL10JkF0"
      },
      "source": [
        "実行例として、グレース・ホッパーの画像と、使いやすいため、Keras の次元トレーニング済み画像分類モデルを使用します。カスタムモデルも使用できますが、これについては後半で説明します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JhVecdzJTsKE"
      },
      "outputs": [],
      "source": [
        "labels_path = tf.keras.utils.get_file(\n",
        "    'ImageNetLabels.txt',\n",
        "    'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')\n",
        "imagenet_labels = np.array(open(labels_path).read().splitlines())"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aEHSYjW6JZHV"
      },
      "outputs": [],
      "source": [
        "pretrained_model = tf.keras.applications.MobileNet()\n",
        "result_before_save = pretrained_model(x)\n",
        "\n",
        "decoded = imagenet_labels[np.argsort(result_before_save)[0,::-1][:5]+1]\n",
        "\n",
        "print(\"Result before saving:\\n\", decoded)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4KIsQDZJ5PS"
      },
      "source": [
        "この画像の予測トップは「軍服」です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8nfznDmHCW6F"
      },
      "outputs": [],
      "source": [
        "mobilenet_save_path = os.path.join(tmpdir, \"mobilenet/1/\")\n",
        "tf.saved_model.save(pretrained_model, mobilenet_save_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pyX-ETE3wX63"
      },
      "source": [
        "save-path は、TensorFlow Serving が使用する規則に従っており、最後のパスコンポーネント（この場合 `1/`）はモデルのバージョンを指します。Tensorflow Serving のようなツールで、相対的な鮮度を区別させることができます。\n",
        "\n",
        "`tf.saved_model.load` で SavedModel を Python に読み込み直し、ホッパー将官の画像がどのように分類されるかを確認できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NP2UpVFRV7N_"
      },
      "outputs": [],
      "source": [
        "loaded = tf.saved_model.load(mobilenet_save_path)\n",
        "print(list(loaded.signatures.keys()))  # [\"serving_default\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K5srGzowfWff"
      },
      "source": [
        "インポートされるシグネチャは、必ずディクショナリを返します。シグネチャ名と出力ディクショナリキーをカスタマイズするには、「[エクスポート中のシグネチャの指定](#specifying_signatures_during_export)」を参照してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ChFLpegYfQGR"
      },
      "outputs": [],
      "source": [
        "infer = loaded.signatures[\"serving_default\"]\n",
        "print(infer.structured_outputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cJYyZnptfuru"
      },
      "source": [
        "SavedModel から推論を実行すると、元のモデルと同じ結果が得られます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9WjGEaS3XfX7"
      },
      "outputs": [],
      "source": [
        "labeling = infer(tf.constant(x))[pretrained_model.output_names[0]]\n",
        "\n",
        "decoded = imagenet_labels[np.argsort(labeling)[0,::-1][:5]+1]\n",
        "\n",
        "print(\"Result after saving and loading:\\n\", decoded)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SJEkdXjTWbtl"
      },
      "source": [
        "## TensorFlow Serving での SavedModel の実行\n",
        "\n",
        "SavedModels は Python から使用可能（詳細は以下参照）ですが、本番環境では通常、Python コードを使用せずに、推論専用のサービスが使用されます。これは、TensorFlow Serving を使用して SavedModel から簡単にセットアップできます。\n",
        "\n",
        "エンドツーエンドのtensorflow-servingの例については、 [TensorFlow Serving RESTチュートリアル](https://www.tensorflow.org/tfx/tutorials/serving/rest_simple)をご覧ください。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Bi0ILzu1XdWw"
      },
      "source": [
        "## ディスク上の SavedModel 形式\n",
        "\n",
        "SavedModel は、変数の値や語彙など、シリアル化されたシグネチャとそれらを実行するために必要な状態を含むディレクトリです。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6u3YZuYZXyTO"
      },
      "outputs": [],
      "source": [
        "!ls {mobilenet_save_path}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ple4X5utX8ue"
      },
      "source": [
        "`saved_model.pb` ファイルは、実際の TensorFlow プログラムまたはモデル、およびテンソル入力を受け入れてテンソル出力を生成する関数を識別する一連の名前付きシグネチャを保存します。\n",
        "\n",
        "SavedModel には、複数のモデルバリアント（`saved_model_cli` への `--tag_set` フラグで識別される複数の `v1.MetaGraphDefs`）が含まれることがありますが、それは稀なことです。複数のモデルバリアントを作成する API には、[`tf.Estimator.experimental_export_all_saved_models`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#experimental_export_all_saved_models) と TensorFlow 1.x の `tf.saved_model.Builder` があります。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pus0dOYTYXbI"
      },
      "outputs": [],
      "source": [
        "!saved_model_cli show --dir {mobilenet_save_path} --tag_set serve"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eALHpGvRZOhk"
      },
      "source": [
        "`variables` ディレクトリには、標準のトレーニングチェックポイントが含まれます（「[トレーニングチェックポイントガイド](./checkpoint.ipynb)」を参照してください）。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EDYqhDlNZAC2"
      },
      "outputs": [],
      "source": [
        "!ls {mobilenet_save_path}/variables"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VKmaZQpHahGh"
      },
      "source": [
        "`assets` ディレクトリには、語彙テーブルを初期化するためのテキストファイルなど、TensorFlow グラフが使用するファイルが含まれます。この例では使用されません。\n",
        "\n",
        "SavedModel には、SavedModel で何をするかといった消費者向けの情報など、TensorFlow グラフで使用されないファイルに使用する `assets.extra` ディレクトリがある場合があります。TensorFlow そのものでは、このディレクトリは使用されません。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zIceoF_CYmaF"
      },
      "source": [
        "## カスタムモデルの保存\n",
        "\n",
        "`tf.saved_model.save` は、`tf.Module` オブジェクトと、`tf.keras.Layer` や `tf.keras.Model` などのサブクラスの保存をサポートしています。\n",
        "\n",
        "`tf.Module` の保存と復元の例を見てみましょう。\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6EPvKiqXMm3d"
      },
      "outputs": [],
      "source": [
        "class CustomModule(tf.Module):\n",
        "\n",
        "  def __init__(self):\n",
        "    super(CustomModule, self).__init__()\n",
        "    self.v = tf.Variable(1.)\n",
        "\n",
        "  @tf.function\n",
        "  def __call__(self, x):\n",
        "    print('Tracing with', x)\n",
        "    return x * self.v\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])\n",
        "  def mutate(self, new_v):\n",
        "    self.v.assign(new_v)\n",
        "\n",
        "module = CustomModule()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "J4FcP-Co3Fnw"
      },
      "source": [
        "`tf.Module` を保存すると、すべての `tf.Variable` 属性、`tf.function` でデコレートされたメソッド、および再帰トラバースで見つかった `tf.Module` が保存されます（この再帰トラバースについては、「[チェックポイントのチュートリアル](./checkpoint.ipynb)」を参照してください）。ただし、Python の属性、関数、およびデータは失われます。つまり、`tf.function` が保存されても、Python コードは保存されません。\n",
        "\n",
        "Python コードが保存されないのであれば、SavedModel は関数をどのようにして復元するのでしょうか。\n",
        "\n",
        "簡単に言えば、`tf.function` は、Python コードをトレースして ConcreteFunction（`tf.Graph` のコーラブルラッパー）を生成することで機能します。`tf.function` を保存すると、実際には `tf.function` の ConcreteFunctions のキャッシュを保存しているのです。\n",
        "\n",
        "`tf.function` と ConcreteFunctions の関係に関する詳細は、「[tf.function ガイド](../../guide/function)」を参照してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "85PUO9iWH7xn"
      },
      "outputs": [],
      "source": [
        "module_no_signatures_path = os.path.join(tmpdir, 'module_no_signatures')\n",
        "module(tf.constant(0.))\n",
        "print('Saving model...')\n",
        "tf.saved_model.save(module, module_no_signatures_path)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2ujwmMQg7OUo"
      },
      "source": [
        "## カスタムモデルの読み込みと使用"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QpxQy5Eb77qJ"
      },
      "source": [
        "Python に SavedModel を読み込むと、すべての `tf.Variable` 属性、`tf.function` でデコレートされたメソッド、および `tf.Module` は、保存された元の `tf.Module` と同じオブジェクト構造で復元されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EMASjADPxPso"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(module_no_signatures_path)\n",
        "assert imported(tf.constant(3.)).numpy() == 3\n",
        "imported.mutate(tf.constant(2.))\n",
        "assert imported(tf.constant(3.)).numpy() == 6"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CDiauvb_99uk"
      },
      "source": [
        "Python コードは保存されないため、新しい入力シグネチャで `tf.function` で呼び出しても失敗します。\n",
        "\n",
        "```python\n",
        "imported(tf.constant([3.]))\n",
        "```\n",
        "\n",
        "<pre>ValueError: Could not find matching function to call for canonicalized inputs ((<tf.tensor shape=\"(1,)\" dtype=\"float32\">,), {}). Only existing signatures are [((TensorSpec(shape=(), dtype=tf.float32, name=u'x'),), {})]. </tf.tensor></pre>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4Vsva3UZ-2sf"
      },
      "source": [
        "### 基本の微調整\n",
        "\n",
        "変数オブジェクトが提供されており、インポートされた関数をバックプロパゲーションできます。単純な事例で SavedModel を微調整（再トレーニング）するには、これで十分です。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PEkQNarJ-7nT"
      },
      "outputs": [],
      "source": [
        "optimizer = tf.optimizers.SGD(0.05)\n",
        "\n",
        "def train_step():\n",
        "  with tf.GradientTape() as tape:\n",
        "    loss = (10. - imported(tf.constant(2.))) ** 2\n",
        "  variables = tape.watched_variables()\n",
        "  grads = tape.gradient(loss, variables)\n",
        "  optimizer.apply_gradients(zip(grads, variables))\n",
        "  return loss"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p41NM6fF---3"
      },
      "outputs": [],
      "source": [
        "for _ in range(10):\n",
        "  # \"v\" approaches 5, \"loss\" approaches 0\n",
        "  print(\"loss={:.2f} v={:.2f}\".format(train_step(), imported.v.numpy()))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XuXtkHSD_KSW"
      },
      "source": [
        "### 一般的な微調整\n",
        "\n",
        "Keras の SavedModel は、より高度な微調整の事例に対処できる、プレーンな `__call__` よりも[詳細な内容](https://github.com/tensorflow/community/blob/master/rfcs/20190509-keras-saved-model.md#serialization-details)を提供します。TensorFlow Hub は、微調整の目的で共有される SavedModel に、該当する場合は次の項目を提供することをお勧めします。\n",
        "\n",
        "- モデルに、フォワードパスがトレーニングと推論で異なるドロップアウトまたはほかのテクニックが使用されている場合（バッチの正規化など）、`__call__` メソッドは、オプションのPython 重視の `training=` 引数を取ります。この引数は、デフォルトで `False` になりますが、`True` に設定することができます。\n",
        "- `__call__` 属性の隣には、対応する変数リストを伴う `.variable` と `.trainable_variable` 属性があります。もともとトレーニング可能であっても、微調整中には凍結されるべき変数は、`.trainable_variables` から省略されます。\n",
        "- レイヤとサブモデルの属性として重みの正規化を表現する Keras のようなフレームワークのために、`.regularization_losses` 属性も使用できます。この属性は、値が合計損失に追加することを目的とした引数無しの関数のリストを保有します。\n",
        "\n",
        "最初の MobileNet の例に戻ると、これらの一部が動作していることを確認できます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y6EUFdY8_PRD"
      },
      "outputs": [],
      "source": [
        "loaded = tf.saved_model.load(mobilenet_save_path)\n",
        "print(\"MobileNet has {} trainable variables: {}, ...\".format(\n",
        "          len(loaded.trainable_variables),\n",
        "          \", \".join([v.name for v in loaded.trainable_variables[:5]])))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B-mQJ8iP_R0h"
      },
      "outputs": [],
      "source": [
        "trainable_variable_ids = {id(v) for v in loaded.trainable_variables}\n",
        "non_trainable_variables = [v for v in loaded.variables\n",
        "                           if id(v) not in trainable_variable_ids]\n",
        "print(\"MobileNet also has {} non-trainable variables: {}, ...\".format(\n",
        "          len(non_trainable_variables),\n",
        "          \", \".join([v.name for v in non_trainable_variables[:3]])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qGlHlbd3_eyO"
      },
      "source": [
        "## エクスポート中のシグネチャの指定\n",
        "\n",
        "TensorFlow Serving や `saved_model_cli` のようなツールは、SavedModel と対話できます。これらのツールがどの ConcreteFunctions を使用するか判定できるように、サービングシグネチャを指定する必要があります。`tf.keras.Model` は、サービングシグネチャを自動的に指定しますが、カスタムモジュールに対して明示的に宣言する必要があります。\n",
        "\n",
        "デフォルトでは、シグネチャはカスタム `tf.Module` で宣言されません。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h-IB5Xa0NxLa"
      },
      "outputs": [],
      "source": [
        "assert len(imported.signatures) == 0"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BiNtaMZSI8Tb"
      },
      "source": [
        "サービングシグネチャを宣言するには、`signatures` kwarg を使用して ConcreteFunction 指定します。単一のシグネチャを指定する場合、シグネチャキーは `'serving_default'` となり、定数 `tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY` として保存されます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_pAdgIORR2yH"
      },
      "outputs": [],
      "source": [
        "module_with_signature_path = os.path.join(tmpdir, 'module_with_signature')\n",
        "call = module.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))\n",
        "tf.saved_model.save(module, module_with_signature_path, signatures=call)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nAzRHR0UT4hv"
      },
      "outputs": [],
      "source": [
        "imported_with_signatures = tf.saved_model.load(module_with_signature_path)\n",
        "list(imported_with_signatures.signatures.keys())\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_gH91j1IR4tq"
      },
      "source": [
        "複数のシグネチャをエクスポートするには、シグネチャキーのディクショナリを ConcreteFunction に渡します。各シグネチャキーは 1 つの ConcreteFunction に対応します。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6VYAiQmLUiox"
      },
      "outputs": [],
      "source": [
        "module_multiple_signatures_path = os.path.join(tmpdir, 'module_with_multiple_signatures')\n",
        "signatures = {\"serving_default\": call,\n",
        "              \"array_input\": module.__call__.get_concrete_function(tf.TensorSpec([None], tf.float32))}\n",
        "\n",
        "tf.saved_model.save(module, module_multiple_signatures_path, signatures=signatures)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8IPx_0RWEx07"
      },
      "outputs": [],
      "source": [
        "imported_with_multiple_signatures = tf.saved_model.load(module_multiple_signatures_path)\n",
        "list(imported_with_multiple_signatures.signatures.keys())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "43_Qv2W_DJZZ"
      },
      "source": [
        "デフォルトでは、出力されたテンソル名は、`output_0` というようにかなり一般的な名前です。出力の名前を制御するには、出力名を出力にマッピングするディクショナリを返すように `tf.function` を変更します。入力の名前は Python 関数の引数名から取られます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ACKPl1X8G1gw"
      },
      "outputs": [],
      "source": [
        "class CustomModuleWithOutputName(tf.Module):\n",
        "  def __init__(self):\n",
        "    super(CustomModuleWithOutputName, self).__init__()\n",
        "    self.v = tf.Variable(1.)\n",
        "\n",
        "  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])\n",
        "  def __call__(self, x):\n",
        "    return {'custom_output_name': x * self.v}\n",
        "\n",
        "module_output = CustomModuleWithOutputName()\n",
        "call_output = module_output.__call__.get_concrete_function(tf.TensorSpec(None, tf.float32))\n",
        "module_output_path = os.path.join(tmpdir, 'module_with_output_name')\n",
        "tf.saved_model.save(module_output, module_output_path,\n",
        "                    signatures={'serving_default': call_output})"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1yGVy4MuH-V0"
      },
      "outputs": [],
      "source": [
        "imported_with_output_name = tf.saved_model.load(module_output_path)\n",
        "imported_with_output_name.signatures['serving_default'].structured_outputs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dk5wWyuMpuHx"
      },
      "source": [
        "## Estimator の SavedModel\n",
        "\n",
        "Estimator は、[`tf.Estimator.export_saved_model`](https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator#export_saved_model) によって SavedModel をエクスポートします。詳細は、「[Estimator ガイド](https://www.tensorflow.org/guide/estimator)」を参照してください。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B9KQq5qzpzbK"
      },
      "outputs": [],
      "source": [
        "input_column = tf.feature_column.numeric_column(\"x\")\n",
        "estimator = tf.estimator.LinearClassifier(feature_columns=[input_column])\n",
        "\n",
        "def input_fn():\n",
        "  return tf.data.Dataset.from_tensor_slices(\n",
        "    ({\"x\": [1., 2., 3., 4.]}, [1, 1, 0, 0])).repeat(200).shuffle(64).batch(16)\n",
        "estimator.train(input_fn)\n",
        "\n",
        "serving_input_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(\n",
        "  tf.feature_column.make_parse_example_spec([input_column]))\n",
        "estimator_base_path = os.path.join(tmpdir, 'from_estimator')\n",
        "estimator_path = estimator.export_saved_model(estimator_base_path, serving_input_fn)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XJ4PJ-Cl4060"
      },
      "source": [
        "SavedModel は、シリアル化された `tf.Example` プロトコルバッファを受け取るため、提供する上で役立ちます。ただし、Python から `tf.saved_model.load` で読み込んで実行することもできます。"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c_BUBBNB1UH9"
      },
      "outputs": [],
      "source": [
        "imported = tf.saved_model.load(estimator_path)\n",
        "\n",
        "def predict(x):\n",
        "  example = tf.train.Example()\n",
        "  example.features.feature[\"x\"].float_list.value.extend([x])\n",
        "  return imported.signatures[\"predict\"](\n",
        "    examples=tf.constant([example.SerializeToString()]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "C1ylWZCQ1ahG"
      },
      "outputs": [],
      "source": [
        "print(predict(1.5))\n",
        "print(predict(3.5))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_IrCCm0-isqA"
      },
      "source": [
        "`tf.estimator.export.build_raw_serving_input_receiver_fn` を使用すると、`tf.train.Example` の代わりに生のテンソルを取る入力関数を作成することができます。"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Co6fDbzw_UnD"
      },
      "source": [
        "## C++ による SavedModel の読み込み\n",
        "\n",
        "C++ バージョンの SavedModel [ローダー](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/cc/saved_model/loader.h)には、SessionOptions と RunOptions を許可しながら、パスから SavedModel を読み込む API が提供されています。読み込まれるグラフに関連付けられたタグを指定する必要があります。読み込まれた SavedModel は SavedModelBundle と呼ばれ、その中に MetaGraphDef とセッションが含まれます。\n",
        "\n",
        "```C++\n",
        "const string export_dir = ... SavedModelBundle bundle; ... LoadSavedModel(session_options, run_options, export_dir, {kSavedModelTagTrain},                &bundle);\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b33KuyEuAO3Z"
      },
      "source": [
        "<a id=\"saved_model_cli\"></a>\n",
        "\n",
        "## SavedModel コマンドラインインターフェースの詳細\n",
        "\n",
        "SavedModel コマンドラインインターフェース（CLI）を使用して、SavedModel のインスペクションと実行を行えます。たとえば、CLI を使用してモデルの `SignatureDef` を検査できます。CLI を使用することで、入力テンソルの dtype と形状がモデルに一致していることを素早く確認することが可能になります。さらに、モデルをテストする場合は、CLI でサンプル入力をさまざまな形式に渡し（Python 式など）て出力をフェッチすることで、サニティーチェックを実施できます。\n",
        "\n",
        "### SavedModel CLI をインストールする\n",
        "\n",
        "大まかに、次の 2 つの方法のいずれかを使って、TensorFlow をインストールできます。\n",
        "\n",
        "- ビルド済みの TensorFlow バイナリをインストールする\n",
        "- ソースコードから TensorFlow をビルドする\n",
        "\n",
        "ビルド済みの TensorFlow バイナリを使って TensorFlow をインストールした場合は、SavedModel CLI はすでにシステムのパス名 `bin/saved_model_cli` にインストールされています。\n",
        "\n",
        "ソースコードから TensorFlow をビルドした場合は、さらに次のコマンドを実行して `saved_model_cli` をビルドする必要があります。\n",
        "\n",
        "```\n",
        "$ bazel build tensorflow/python/tools:saved_model_cli\n",
        "```\n",
        "\n",
        "### コマンドの概要\n",
        "\n",
        "SavedModel CLI は、SavedModel に使用する次の 2 つのコマンドをサポートしています。\n",
        "\n",
        "- `show` - SavedModel で利用できる計算を表示します。\n",
        "- `run` - SavedModel の計算を実行します。\n",
        "\n",
        "### `show` コマンド\n",
        "\n",
        "SavedModel には、1 つ以上のモデルバリアント（厳密には `v1.MetaGraphDef`）が含まれており、タグセットで識別されます。モデルを提供する上で、各モデルバリアントに含まれる `SignatureDef` の種類やその入力と出力について悩むことがあるかもしれません。そのようなときに `show` コマンドを使用すれば、SavedModel の中身を階層的に調べることができます。次はその構文を示しています。\n",
        "\n",
        "```\n",
        "usage: saved_model_cli show [-h] --dir DIR [--all] [--tag_set TAG_SET] [--signature_def SIGNATURE_DEF_KEY]\n",
        "```\n",
        "\n",
        "たとえば、次のコマンドでは、SavedModel 内の利用可能なすべてタグセットが表示されます。\n",
        "\n",
        "```\n",
        "$ saved_model_cli show --dir /tmp/saved_model_dir The given SavedModel contains the following tag-sets: serve serve, gpu\n",
        "```\n",
        "\n",
        "次のコマンドでは、タグセット当たりの利用可能なすべての `SignatureDef` キーが表示されます。\n",
        "\n",
        "```\n",
        "$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve The given SavedModel `MetaGraphDef` contains `SignatureDefs` with the following keys: SignatureDef key: \"classify_x2_to_y3\" SignatureDef key: \"classify_x_to_y\" SignatureDef key: \"regress_x2_to_y3\" SignatureDef key: \"regress_x_to_y\" SignatureDef key: \"regress_x_to_y2\" SignatureDef key: \"serving_default\"\n",
        "```\n",
        "\n",
        "タグセット内に*複数*のタグが存在する場合は、次の例のように、すべてのタグをカンマ区切りで指定する必要があります。\n",
        "\n",
        "<pre>$ saved_model_cli show --dir /tmp/saved_model_dir --tag_set serve,gpu</pre>\n",
        "\n",
        "特定の `SignatureDef` に対するすべての入力と出力の TensorInfo を表示するには、 `signature_def` オプションに `SignatureDef` キーを渡します。これは、後で計算グラフを実行する際の入力テンソルのテンソルキー値、dtype、および形状を知るうえで非常に役立ちます。次に例を示します。\n",
        "\n",
        "```\n",
        "$ saved_model_cli show --dir \\\n",
        "/tmp/saved_model_dir --tag_set serve --signature_def serving_default\n",
        "The given SavedModel SignatureDef contains the following input(s):\n",
        "  inputs['x'] tensor_info:\n",
        "      dtype: DT_FLOAT\n",
        "      shape: (-1, 1)\n",
        "      name: x:0\n",
        "The given SavedModel SignatureDef contains the following output(s):\n",
        "  outputs['y'] tensor_info:\n",
        "      dtype: DT_FLOAT\n",
        "      shape: (-1, 1)\n",
        "      name: y:0\n",
        "Method name is: tensorflow/serving/predict\n",
        "```\n",
        "\n",
        "SavedModel で利用できるすべての情報を表示するには、次の例のように `--all` オプションを使用します。\n",
        "\n",
        "<pre>$ saved_model_cli show --dir /tmp/saved_model_dir --all<br>MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:<br><br>signature_def['classify_x2_to_y3']:<br>  The given SavedModel SignatureDef contains the following input(s):<br>    inputs['inputs'] tensor_info:<br>        dtype: DT_FLOAT<br>        shape: (-1, 1)<br>        name: x2:0<br>  The given SavedModel SignatureDef contains the following output(s):<br>    outputs['scores'] tensor_info:<br>        dtype: DT_FLOAT<br>        shape: (-1, 1)<br>        name: y3:0<br>  Method name is: tensorflow/serving/classify<br><br>...<br><br>signature_def['serving_default']:<br>  The given SavedModel SignatureDef contains the following input(s):<br>    inputs['x'] tensor_info:<br>        dtype: DT_FLOAT<br>        shape: (-1, 1)<br>        name: x:0<br>  The given SavedModel SignatureDef contains the following output(s):<br>    outputs['y'] tensor_info:<br>        dtype: DT_FLOAT<br>        shape: (-1, 1)<br>        name: y:0<br>  Method name is: tensorflow/serving/predict</pre>\n",
        "\n",
        "### `run` コマンド\n",
        "\n",
        "グラフ計算を実行して、入力を渡してから出力を表示（または保存）するには `run` コマンドを呼び出します。次に構文を示します。\n",
        "\n",
        "```\n",
        "usage: saved_model_cli run [-h] --dir DIR --tag_set TAG_SET --signature_def                            SIGNATURE_DEF_KEY [--inputs INPUTS]                            [--input_exprs INPUT_EXPRS]                            [--input_examples INPUT_EXAMPLES] [--outdir OUTDIR]                            [--overwrite] [--tf_debug]\n",
        "```\n",
        "\n",
        "`run` コマンドでは、次の 3 つの方法でモデルに入力を渡すことができます。\n",
        "\n",
        "- `--inputs` オプション: ファイルに numpy ndarray を渡すことができます。\n",
        "- `--input_exprs` オプション: Python 式を渡すことができます。\n",
        "- `--input_examples` オプション: `tf.train.Example` を渡すことができます。\n",
        "\n",
        "#### `--inputs`\n",
        "\n",
        "ファイルに入力データを渡すには、次のような形式で `--inputs` オプションを指定します。\n",
        "\n",
        "```bsh\n",
        "--inputs <INPUTS>\n",
        "```\n",
        "\n",
        "上記の *INPUTS* は、次のいずれかの形式です。\n",
        "\n",
        "- `<input_key>=<filename>`\n",
        "- `<input_key>=<filename>[<variable_name>]`\n",
        "\n",
        "複数の *INPUTS* を渡すことができます。複数の入力（INPUTS）を渡す場合は、セミコロン区切りで *INPUTS* を指定します。\n",
        "\n",
        "`saved_model_cli` は `numpy.load` を使用して *filename* を読み込みます。*filename* は、次のいずれかの形式です。\n",
        "\n",
        "- `.npy`\n",
        "- `.npz`\n",
        "- ピクル形式\n",
        "\n",
        "`.npy` ファイルには、必ず numpy ndarray が含まれます。そのため、`.npy` ファイルから読み込む場合、コンテンツは指定された入力テンソルに直接割り当てられます。その `.npy` ファイルで *variable_name* を指定すると、*variable_name* は無視され、警告が発行されます。\n",
        "\n",
        "`.npz`（zip）ファイルから読み込む場合、任意で*variable_name* を指定して zip ファイル内の変数を識別し、入力テンソルキーに読み込むことができます。*variable_name* を指定しない場合は、SavedModel CLI は、ファイルが 1 つだけ zip ファイルに含まれていることを確認し、指定された入力テンソルキーにそれを読み込みます。\n",
        "\n",
        "ピクル形式から読み込む際に `variable_name` が大かっこで指定されていない場合、そのピクルファイルに含まれるもの何であれ、指定された入力テンソルキーに渡されます。そうでない場合、SavedModel CLI はピクルファイルにディクショナリが保存されているとみなし、*variable_name* に対応する値が使用されます。\n",
        "\n",
        "#### `--input_exprs`\n",
        "\n",
        "Python 式で入力を渡すには、`--input_exprs` オプションを指定します。これは、データファイルが手元にない場合に、モデルの `SignatureDef` の dtype と形状に一致する何らかの単純な入力を使ってサニティーチェックを実施する場合に便利です。次に例を示します。\n",
        "\n",
        "```bsh\n",
        "`<input_key>=[[1],[2],[3]]`\n",
        "```\n",
        "\n",
        "Python 式のほか、次のように numpy 関数を渡すこともできます。\n",
        "\n",
        "```bsh\n",
        "`<input_key>=np.ones((32,32,3))`\n",
        "```\n",
        "\n",
        "（`numpy` モジュールはすでに `np` として使用できるようになっていることに注意してください。）\n",
        "\n",
        "#### `--input_examples`\n",
        "\n",
        "`tf.train.Example` を入力として渡すには、`--input_examples` オプションを指定します。各入力キーに対し、ディクショナリのリストを取り、各ディクショナリは `tf.train.Example` のインスタンスです。ディクショナリキーは特徴量であり、値は各特徴量の値リストです。次に例を示します。\n",
        "\n",
        "```bsh\n",
        "`<input_key>=[{\"age\":[22,24],\"education\":[\"BS\",\"MS\"]}]`\n",
        "```\n",
        "\n",
        "#### 出力を保存する\n",
        "\n",
        "デフォルトでは、SavedModel CLI は出力を stdout に書き込みます。ディレクトリを `--outdir` オプションに渡すと、出力は あるディレクトリ配下に `.npy` ファイルとして保存されます。ファイル名には出力テンソルキーに因んだ名前が付けられます。\n",
        "\n",
        "既存の出力ファイルを上書きするには、`--overwrite` を使用してください。\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "saved_model.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
